
Conversation

Collaborator

@Tcc0403 Tcc0403 commented Oct 1, 2025

Summary

Add a HAS_GRADIENTS flag to the cross entropy kernel so that gradient computation is skipped entirely when it isn't needed (e.g. under torch.no_grad()).
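
As a rough illustration of the pattern (not the actual Liger kernel — the kernel name, the single-block-per-row layout, and the wrapper below are assumptions for this sketch; only the idea of a `HAS_GRADIENTS` constexpr gating the gradient work comes from this PR):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def ce_fwd_kernel(
    logits_ptr,    # (n_rows, n_cols); gradients are written back in place when needed
    labels_ptr,    # (n_rows,) target class indices
    loss_ptr,      # (n_rows,) per-row loss output
    n_cols,
    stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_GRADIENTS: tl.constexpr,  # compile-time flag: all grad work is dead code when False
):
    row = tl.program_id(0)
    logits_ptr += row * stride
    label = tl.load(labels_ptr + row)

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols
    logits = tl.load(logits_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)

    # numerically stable log-sum-exp
    m = tl.max(logits, axis=0)
    lse = m + tl.log(tl.sum(tl.exp(logits - m), axis=0))

    target_logit = tl.load(logits_ptr + label).to(tl.float32)
    tl.store(loss_ptr + row, lse - target_logit)

    if HAS_GRADIENTS:
        # d(loss)/d(logits) = softmax(logits) - one_hot(label), stored into the logits buffer
        grad = tl.exp(logits - lse)
        grad = tl.where(offs == label, grad - 1.0, grad)
        tl.store(logits_ptr + offs, grad.to(logits_ptr.dtype.element_ty), mask=mask)


def ce_forward(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean cross entropy over rows; skips gradient work when no backward will follow."""
    n_rows, n_cols = logits.shape
    loss = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    ce_fwd_kernel[(n_rows,)](
        logits, labels, loss,
        n_cols, logits.stride(0),
        BLOCK_SIZE=triton.next_power_of_2(n_cols),
        HAS_GRADIENTS=logits.requires_grad and torch.is_grad_enabled(),
    )
    return loss.mean()
```

Since `HAS_GRADIENTS` is a constexpr, Triton compiles a separate specialization with the gradient branch removed, rather than just skipping it at runtime.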

Testing Done

Cross Entropy forward with no_grad (benchmark screenshot)

Fused Linear Cross Entropy forward with no_grad (benchmark screenshot)
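
For reference, a minimal way to reproduce this kind of measurement (illustrative only; the numbers above come from the repo's own benchmark scripts, and the `liger_kernel.transformers.LigerCrossEntropyLoss` import path is assumed from the package layout):

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss  # import path assumed


def time_forward_no_grad(n_rows=4096, vocab=128256, dtype=torch.bfloat16, iters=50):
    logits = torch.randn(n_rows, vocab, device="cuda", dtype=dtype)
    target = torch.randint(0, vocab, (n_rows,), device="cuda")
    loss_fn = LigerCrossEntropyLoss()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():              # forward only, so the gradient path should be skipped
        for _ in range(5):             # warmup
            loss_fn(logits, target)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            loss_fn(logits, target)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # ms per forward call
```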

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Tcc0403 Tcc0403 changed the title feat(ce,-flce): decouple gradients computation for no_grad mode feat(ce,flce): decouple gradients computation for no_grad mode Oct 1, 2025
@Tcc0403 Tcc0403 force-pushed the tcc/flce-eval-no-grad branch from 291141c to 9ab603a Compare October 1, 2025 13:40
Signed-off-by: Tcc0403 <[email protected]>
Signed-off-by: Tcc0403 <[email protected]>
@Tcc0403 Tcc0403 requested review from momochen and shimizust October 8, 2025 10:53
Collaborator

@shimizust shimizust left a comment

Nice, thanks for adding this. Is the forward FLCE still significantly slower than HF because we're still computing grad_input and applying the token scaling logic? Also, do you know why fp32 accum is faster?

@lancerts lancerts merged commit 5c2a04d into main Oct 11, 2025
2 checks passed
@lancerts lancerts deleted the tcc/flce-eval-no-grad branch October 11, 2025 15:44
Collaborator Author

Tcc0403 commented Oct 11, 2025

@shimizust The slower forward pass is kinda expected because:

  1. instead of one big matmul, we slice the input row-wise and do multiple matmuls. That adds some kernel launch overhead, and it's not guaranteed to invoke the most efficient kernels (tiling size, tail effect, etc., from the kernel's perspective).
  2. the multiple cross entropy kernel launches add similar overhead, and it all adds up.

I just found that we can remove this line in eval mode as well; cutting another matmul for each iteration should be significant:

grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

and some grad tensor allocations too.
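
Roughly, the structure being described looks like this (an illustrative sketch, not the actual Liger FLCE code; grad_weight handling and the 1/n_rows gradient scaling are omitted for brevity):

```python
import torch
import torch.nn.functional as F


def flce_forward(x, weight, target, chunk_size=4096, compute_grads=True):
    """x: (n_rows, hidden), weight: (vocab, hidden), target: (n_rows,)."""
    n_rows = x.shape[0]
    grad_input = torch.zeros_like(x) if compute_grads else None  # skipped in eval/no_grad
    total_loss = x.new_zeros((), dtype=torch.float32)

    for start in range(0, n_rows, chunk_size):
        end = min(start + chunk_size, n_rows)
        logits_chunk = x[start:end] @ weight.T        # per-chunk matmul instead of one big one
        total_loss += F.cross_entropy(
            logits_chunk.float(), target[start:end], reduction="sum"
        )

        if compute_grads:
            # d(loss)/d(logits) = softmax(logits) - one_hot(target)
            grad_logits_chunk = torch.softmax(logits_chunk.float(), dim=-1)
            rows = torch.arange(end - start, device=x.device)
            grad_logits_chunk[rows, target[start:end]] -= 1.0
            grad_logits_chunk = grad_logits_chunk.to(x.dtype)
            # the extra matmul (and the grad buffers) that eval mode can drop entirely
            grad_input[start:end] = grad_logits_chunk @ weight

    return total_loss / n_rows, grad_input
```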

shimizust pushed a commit that referenced this pull request Oct 16, 2025
## Summary
follow-up #894 


## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Tcc0403 <[email protected]>
